PyTorch tutorial: LSTM
Notes written from what I remember of reading 『ゼロから作るDeep Learning②』.
In the case of an LSTM, for each element in the sequence, there is a corresponding hidden state h_t, which in principle can contain information from arbitrary points earlier in the sequence.
PyTorch’s LSTM expects all of its inputs to be 3D tensors.
The first axis is the sequence itself, the second indexes instances in the mini-batch, and the third indexes elements of the input.
In the example here, the 1st and 2nd axes (dimensions) both have size 1:
2nd: mini-batches are not considered (?)
1st: the sequence is advanced one element at a time
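To make the axis layout concrete, here is a minimal sketch that only checks shapes. The sizes (sequence length 5, hidden size 4) and the seed are arbitrary choices for illustration and are not taken from the tutorial.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only so the run is repeatable

# Illustrative sizes (assumptions, not from the tutorial)
seq_len, batch_size, input_size, hidden_size = 5, 1, 3, 4

# batch_first defaults to False, so inputs are (seq_len, batch, input_size)
lstm = nn.LSTM(input_size, hidden_size)

# A single sequence element has shape (input_size,) ...
x_t = torch.randn(input_size)
# ... and is reshaped to (1, 1, input_size): one time step, one batch instance
step_input = x_t.view(1, 1, -1)

# A whole sequence for one batch instance: (seq_len, 1, input_size)
sequence = torch.randn(seq_len, batch_size, input_size)

# When no initial state is passed, the LSTM starts from zeros
out_step, (h_step, c_step) = lstm(step_input)
out_seq, (h_seq, c_seq) = lstm(sequence)

print(step_input.shape)  # one time step
print(out_step.shape)    # one output per time step
print(out_seq.shape)     # seq_len outputs
print(h_seq.shape)       # final hidden state only
```

The output keeps the first two axes of the input and replaces the third with the hidden size, while the returned state always has a leading axis of size 1 for a single-layer, unidirectional LSTM.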
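>> import torch
>> import torch.nn as nn
>> torch.manual_seed(1)
>> lstm = nn.LSTM(3, 3)  # input dim is 3, output dim is 3
>> inputs = [torch.randn(1, 3) for _ in range(5)]  # make a sequence of length 5
>> # initialize the state as a tuple (h_0, c_0)
>> hidden = (torch.randn(1, 1, 3), torch.randn(1, 1, 3))
>> for i in inputs:
...     print(i)
...     # step through the sequence one element at a time;
...     # after each step, hidden holds the updated (h_t, c_t)
...     out, hidden = lstm(i.view(1, 1, -1), hidden)
...     print(out)
...     print(hidden)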
tensor([[-0.5525,  0.6355, -0.3968]])
tensor([[-0.6571, -1.6428,  0.9803]])
tensor([[-0.0421, -0.8206,  0.3133]])
tensor([[-1.1352,  0.3773, -0.2824]])
tensor([[-2.5667, -1.4303,  0.5009]])
>> inputs = torch.cat(inputs).view(len(inputs), 1, -1)  # Add the extra 2nd dimension
>> inputs.size()
>> hidden = (torch.randn(1, 1, 3), torch.randn(1, 1, 3))
>> out, hidden = lstm(inputs, hidden)
>> print(out)
tensor([[[-0.2945, -0.3090,  0.0366]],

        [[-0.5580, -0.1228,  0.0714]],

        [[-0.4122, -0.0834,  0.0380]],

        [[-0.1954, -0.0010,  0.0192]],

        [[-0.3722,  0.0672,  0.1393]]], grad_fn=<StackBackward0>)
>> print(hidden)
The first value returned by the LSTM ("out") contains the hidden states for every step of the sequence.
The second ("hidden") is only the most recent state, returned as a tuple of the final hidden state and the final cell state. Compare the last slice of "out" with the first element of "hidden": they are the same.
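The relationship between the two return values can be checked directly. The following is a small sketch under the same setup as above (a single-layer, unidirectional `nn.LSTM(3, 3)`); the random input tensor and the names `h_n`, `c_n`, and `last_matches` are illustrative choices, not part of the tutorial.

```python
import torch
import torch.nn as nn

torch.manual_seed(1)

lstm = nn.LSTM(3, 3)
inputs = torch.randn(5, 1, 3)  # (seq_len, batch, input_size), random stand-in data
hidden = (torch.randn(1, 1, 3), torch.randn(1, 1, 3))  # (h_0, c_0)

# The second return value is a tuple: final hidden state and final cell state
out, (h_n, c_n) = lstm(inputs, hidden)

# out[-1] is the hidden state at the last time step, shape (batch, hidden_size);
# h_n[0] is the final hidden state of the only layer, with the same shape
last_matches = torch.allclose(out[-1], h_n[0])
print(last_matches)

# The cell state c_n is returned only in the tuple; it does not appear in out
print(out.shape, h_n.shape, c_n.shape)
```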
"out" gives you access to all hidden states in the sequence.
"hidden" lets you continue the sequence and backpropagate, by passing it as an argument to the LSTM at a later time.
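As a sketch of what "continuing the sequence" means in practice, the example below feeds a sequence in two chunks, passing the state returned by the first call into the second, and compares the result with a single call over the whole sequence. The split point, the zero initial state, and the variable names are assumptions made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(1)

lstm = nn.LSTM(3, 3)
inputs = torch.randn(5, 1, 3)  # (seq_len, batch, input_size), random stand-in data
h0 = (torch.zeros(1, 1, 3), torch.zeros(1, 1, 3))  # explicit zero initial state

# Process the whole sequence in one call
out_full, hidden_full = lstm(inputs, h0)

# Process the same sequence in two chunks, carrying the state across
out_a, hidden_a = lstm(inputs[:3], h0)
out_b, hidden_b = lstm(inputs[3:], hidden_a)
out_chunked = torch.cat([out_a, out_b], dim=0)

# Both routes should agree up to floating-point tolerance
same_outputs = torch.allclose(out_full, out_chunked, atol=1e-6)
same_hidden = torch.allclose(hidden_full[0], hidden_b[0], atol=1e-6)
same_cell = torch.allclose(hidden_full[1], hidden_b[1], atol=1e-6)
print(same_outputs, same_hidden, same_cell)
```

Because `hidden_a` is passed along without being detached, gradients from the second chunk can still flow back into the first. Calling `.detach()` on both tensors of the tuple before reusing it would cut backpropagation at the chunk boundary.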